import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
import numpy as np
import copy
import random
#import networkx as nx
#import matplotlib.pyplot as plt

# Bootstrap configuration: mini-batch size (the default `bs` argument of
# LossScaledTrace) and the number of mini-batches.
bootstrap_bs, BatchNum = 5, 100
# Number of model copies evaluated and averaged by NeighborLossScaledTrace
# (the alternative value 50 is kept here for reference).
NumSample = 1
# RNG seeds: `seed` is passed to torch.manual_seed in LossScaledTrace and to
# random.seed in NeighborLossScaledTrace; `npseed` seeds numpy and `torchseed`
# seeds torch in NeighborLossScaledTrace.
seed, npseed, torchseed = 2, 6, 5

def TakeGradient(x, y, model, criterion):
    """Return the flattened gradient of ``criterion(model(x), y)``.

    The gradient is taken with respect to every parameter of ``model`` that
    has ``requires_grad=True``, in ``model.parameters()`` order.  Neither the
    inputs nor the parameters' ``.grad`` attributes are modified.

    Args:
        x: Input passed to ``model``.
        y: Target passed to ``criterion`` together with the model output.
        model: Module whose trainable parameters are differentiated.
        criterion: Callable ``criterion(prediction, target)`` returning a
            scalar loss tensor.

    Returns:
        A tensor of shape ``(1, P)``, where ``P`` is the total number of
        trainable parameter elements.  If the model has no trainable
        parameters an empty ``(1, 0)`` tensor is returned.

    Raises:
        RuntimeError: Propagated from ``torch.autograd.grad``, e.g. when a
            trainable parameter does not take part in the loss.
    """
    # The forward pass is run unconditionally so that any side effects of the
    # model (and any errors of the criterion) occur as they always did.
    loss = criterion(model(x), y)

    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        return torch.tensor([]).view(1, -1)

    # A single backward pass yields the gradients of all parameters; no
    # input copies are needed because nothing here mutates x or y (and
    # deep-copying would fail for non-leaf tensors).
    grads = torch.autograd.grad(loss, params)
    return torch.cat([g.reshape(-1) for g in grads]).view(1, -1)

def _flatten_grads(grads):
    """Concatenate a sequence of gradient tensors into one (P, 1) column vector."""
    return torch.cat([g.reshape(-1) for g in grads]).view(-1, 1)


def LossScaledTrace(test_model, inputs, labels, d, N_train, bs=bootstrap_bs):
    """Compute trace statistics of the outer-product "Hessian" and gradient covariance.

    Working on a deep copy of ``test_model`` and using ``torch.nn.MSELoss``:

    * ``Hessian`` is the sum over samples of ``g_f g_f^T`` divided by
      ``N_train``, where ``g_f`` is the gradient of the model output for one
      sample with respect to the trainable parameters.
    * ``CovarianceMatrix`` is the sum over samples of ``g_l g_l^T`` divided by
      ``N_train``, minus ``G G^T``, where ``g_l`` is the per-sample loss
      gradient and ``G`` is the gradient of the loss on the full data set.

    Args:
        test_model: Model to analyse; it is deep-copied and never modified.
        inputs: Array-like or tensor of inputs whose first axis indexes samples.
        labels: Array-like or tensor of targets with the same first dimension.
        d: Ignored.  The parameter dimension is recomputed from the model's
            trainable parameters; the argument is kept for compatibility.
        N_train: Normaliser applied to the per-sample sums (expected to be
            the number of samples in ``inputs``).
        bs: Unused; kept for backward compatibility.

    Returns:
        A tuple ``(trace(H @ C), sqrt(trace(H @ H)), trace(H), trace(C), H, C)``
        where the first four entries are 0-dim tensors and ``H`` / ``C`` are
        ``(P, P)`` tensors, ``P`` being the number of trainable parameter
        elements.

    Raises:
        ValueError: If the model has no trainable parameters.
        RuntimeError: Propagated from autograd, e.g. when the model output
            for a single sample is not a single element.
    """
    # Retained from the original implementation: callers may depend on the
    # global torch RNG being reseeded by this call.
    torch.manual_seed(seed)
    model = copy.deepcopy(test_model)

    # Only trainable parameters contribute to the gradient vectors, so the
    # matrix dimension must be counted over exactly the same set.
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("test_model has no trainable parameters")
    d = sum(p.numel() for p in params)

    # Accumulate in the parameters' own dtype/device so double-precision
    # gradients are not silently truncated to float32.
    Hessian = torch.zeros(d, d, dtype=params[0].dtype, device=params[0].device)
    CovarianceMatrix = torch.zeros_like(Hessian)
    criterion = torch.nn.MSELoss()

    # Accepts numpy arrays (shared memory, like torch.from_numpy) and tensors.
    inputs = torch.as_tensor(inputs)
    labels = torch.as_tensor(labels)

    # Gradient of the loss evaluated on the whole data set at once.
    FullLoss = criterion(model(inputs), labels)
    FullGradient = _flatten_grads(torch.autograd.grad(FullLoss, params))

    # Per-sample pass: one forward evaluation serves both the loss gradient
    # and the gradient of the raw model output.
    dataset = TensorDataset(inputs, labels)
    single_loader = DataLoader(dataset=dataset, batch_size=1)
    for x, y in single_loader:
        output = model(x)
        loss = criterion(output, y)
        LossGradient = _flatten_grads(
            torch.autograd.grad(loss, params, retain_graph=True)
        )
        FunctionGradient = _flatten_grads(torch.autograd.grad(output, params))
        Hessian += torch.mm(FunctionGradient, FunctionGradient.t())
        CovarianceMatrix += torch.mm(LossGradient, LossGradient.t())

    Hessian = Hessian / N_train
    CovarianceMatrix = CovarianceMatrix / N_train
    # Subtract the outer product of the mean gradient to centre the second moment.
    CovarianceMatrix -= torch.mm(FullGradient, FullGradient.t())

    return (
        torch.trace(torch.mm(Hessian, CovarianceMatrix)),
        torch.sqrt(torch.trace(torch.mm(Hessian, Hessian))),
        torch.trace(Hessian),
        torch.trace(CovarianceMatrix),
        Hessian,
        CovarianceMatrix,
    )

def NeighborLossScaledTrace(test_model, inputs, labels, d, N_train, radius):
    """Average the LossScaledTrace statistics over ``NumSample`` model copies.

    All three random number generators (``random``, ``numpy`` and ``torch``)
    are reseeded first.  Each iteration deep-copies ``test_model`` and hands
    the copy to ``LossScaledTrace``.  No perturbation is applied to the copy,
    so ``radius`` is currently unused.

    Args:
        test_model: Model to analyse; only copies of it are evaluated.
        inputs: Inputs forwarded to ``LossScaledTrace``.
        labels: Targets forwarded to ``LossScaledTrace``.
        d: Forwarded to ``LossScaledTrace``.
        N_train: Forwarded to ``LossScaledTrace``.
        radius: Unused.

    Returns:
        The means of the four scalar statistics (product trace, Frobenius
        term, Hessian trace, covariance trace) followed by the Hessian and
        covariance matrix of the last evaluated copy.
    """
    random.seed(seed)
    np.random.seed(npseed)
    torch.manual_seed(torchseed)

    # One row per scalar statistic, one column per evaluated copy:
    # 0 = product trace, 1 = Frobenius term, 2 = Hessian trace, 3 = covariance trace.
    stats = np.zeros((4, NumSample))
    for k in range(NumSample):
        candidate = copy.deepcopy(test_model)
        outcome = LossScaledTrace(candidate, inputs, labels, d, N_train)
        (stats[0, k], stats[1, k], stats[2, k], stats[3, k],
         Hessian, CovarianceMatrix) = outcome

    return (
        stats[0].mean(),
        stats[1].mean(),
        stats[2].mean(),
        stats[3].mean(),
        Hessian,
        CovarianceMatrix,
    )

def LimitNeighborLossScaledTrace(test_model, inputs, labels, d, N_train, radiuses):
    """Average the NeighborLossScaledTrace statistics over several radii.

    Args:
        test_model: Model forwarded to ``NeighborLossScaledTrace``.
        inputs: Inputs forwarded to ``NeighborLossScaledTrace``.
        labels: Targets forwarded to ``NeighborLossScaledTrace``.
        d: Forwarded to ``NeighborLossScaledTrace``.
        N_train: Forwarded to ``NeighborLossScaledTrace``.
        radiuses: Non-empty sequence of radii; one evaluation is made per entry.

    Returns:
        The means over all radii of the four scalar statistics (product
        trace, Frobenius term, Hessian trace, covariance trace) followed by
        the Hessian and covariance matrix obtained for the last radius.

    Raises:
        ValueError: If ``radiuses`` is empty (there would be nothing to
            average and no matrices to return).
    """
    num_radii = len(radiuses)
    if num_radii == 0:
        raise ValueError("radiuses must contain at least one radius")

    ProductTraces = np.zeros(num_radii)
    Frobeniuses = np.zeros(num_radii)
    HessianTraces = np.zeros(num_radii)
    CovarianceTraces = np.zeros(num_radii)
    for i, radius in enumerate(radiuses):
        (ProductTraces[i], Frobeniuses[i], HessianTraces[i], CovarianceTraces[i],
         Hessian, CovarianceMatrix) = NeighborLossScaledTrace(
            test_model, inputs, labels, d, N_train, radius
        )
    return (
        np.mean(ProductTraces),
        np.mean(Frobeniuses),
        np.mean(HessianTraces),
        np.mean(CovarianceTraces),
        Hessian,
        CovarianceMatrix,
    )

